import torch
from torch import nn
import torch.nn.functional as F
from model.Transformer import Transformer
import model.resnet as models
from model.PSPNet import OneModel as PSPNet
from einops import rearrange

# add
import clip
import math
from model.get_cam import get_img_cam
from pytorch_grad_cam import GradCAM
from clip.clip_text import new_class_names, new_class_names_coco,BACKGROUND_CATEGORY_EACH,BACKGROUND_CATEGORY_COCO_EACH
from .segformer_head import SegFormerHead
from .PAR import PAR
import numpy as np
from .PFENet import PFENet

def zeroshot_classifier(classnames, templates, model):
    """Encode class names into CLIP zero-shot text embeddings.

    Every class name is formatted into each template, tokenized and encoded
    with ``model.encode_text``.  The per-template embeddings are
    L2-normalised and averaged, and the mean is normalised again.  Everything
    runs on CPU under ``torch.no_grad()``.

    Args:
        classnames: iterable of class names; each is substituted via ``str.format``.
        templates: format strings with a single ``{}`` placeholder.
        model: CLIP-like model exposing ``encode_text``.

    Returns:
        Tensor of shape (K, D), one unit-norm row per class name, where K is
        the number of class names.
    """
    device = "cpu"

    def _class_embedding(name):
        # One prompt per template; each prompt embedding is unit-normalised
        # before averaging so every template contributes equally.
        prompts = [tpl.format(name) for tpl in templates]
        tokens = clip.tokenize(prompts).to(device)
        per_prompt = model.encode_text(tokens)
        per_prompt = per_prompt / per_prompt.norm(dim=-1, keepdim=True)
        centroid = per_prompt.mean(dim=0)
        return centroid / centroid.norm()

    with torch.no_grad():
        columns = [_class_embedding(name) for name in classnames]
        weights = torch.stack(columns, dim=1).to(device)
    return weights.t()


def zeroshot_classifier_bg(classnames, templates, model):
    """Build one zero-shot text-embedding tensor per background entry.

    Each element of ``classnames`` is passed unchanged as the ``classnames``
    argument of ``zeroshot_classifier``.

    Args:
        classnames: iterable of background entries.
        templates: prompt templates forwarded to ``zeroshot_classifier``.
        model: CLIP-like model forwarded to ``zeroshot_classifier``.

    Returns:
        List with one ``zeroshot_classifier`` result per entry, in order.
    """
    return [zeroshot_classifier(entry, templates, model) for entry in classnames]

def Weighted_GAP(supp_feat, mask):
    """Masked global average pooling.

    Averages ``supp_feat`` over the spatial positions weighted by ``mask``.
    A constant 0.0005 is added to the mask area, so an all-zero mask yields
    zeros instead of dividing by zero.

    Args:
        supp_feat: feature map of shape (B, C, H, W).
        mask: weights broadcastable to ``supp_feat`` with the same spatial
            size, e.g. (B, 1, H, W).

    Returns:
        Tensor of shape (B, C, 1, 1).
    """
    masked_feat = supp_feat * mask
    height, width = masked_feat.shape[-2:]
    window = (masked_feat.size()[2], masked_feat.size()[3])
    # avg_pool over the full map times H*W == spatial sum
    mask_area = F.avg_pool2d(mask, window) * height * width + 0.0005
    feat_sum = F.avg_pool2d(input=masked_feat, kernel_size=masked_feat.shape[-2:]) * height * width
    return feat_sum / mask_area


def get_similarity(q, s, mask):
    """Max cosine similarity of every query location to the masked support.

    Support features outside the binarised mask are zeroed.  Then, for each
    query location, the cosine similarity to every support location is
    computed and the maximum is kept.  Masked-out support locations therefore
    contribute a similarity of 0.

    Args:
        q: query features, shape (B, C, H, W). H and W need not be equal.
        s: support features, shape (B, C, H, W), the same as ``q``.
        mask: support mask, shape (B, H', W') or (B, 1, H', W').  Entries
            equal to 1 are foreground.  It is resized to (H, W) with the
            default (nearest) interpolation.

    Returns:
        Tensor of shape (B, 1, H, W).
    """
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    mask = F.interpolate((mask == 1).float(), q.shape[-2:])
    cosine_eps = 1e-7

    bsize, ch_sz, height, width = q.shape
    s = s * mask  # zero out background support locations

    query = q.reshape(bsize, ch_sz, -1)                    # B x C x HW
    query_norm = torch.norm(query, 2, 1, True)             # B x 1 x HW
    supp = s.reshape(bsize, ch_sz, -1).permute(0, 2, 1)    # B x HW x C
    supp_norm = torch.norm(supp, 2, 2, True)               # B x HW x 1

    # B x HW(support) x HW(query) cosine similarities
    similarity = torch.bmm(supp, query) / (torch.bmm(supp_norm, query_norm) + cosine_eps)
    # keep the best-matching support location for every query location
    similarity = similarity.max(1)[0]
    # use the real (H, W) so non-square feature maps work too
    return similarity.view(bsize, 1, height, width)


def get_gram_matrix(fea):
    """Channel-by-channel cosine-similarity (normalised Gram) matrix.

    Args:
        fea: feature map of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, C, C).  Entry [b, i, j] is the cosine similarity
        between the flattened channels i and j (1e-7 added to the denominator).
    """
    batch, channels, height, width = fea.shape
    flat = fea.reshape(batch, channels, height * width)   # B x C x N
    flat_t = flat.permute(0, 2, 1)                        # B x N x C
    row_norm = flat.norm(2, 2, True)                      # B x C x 1
    col_norm = flat_t.norm(2, 1, True)                    # B x 1 x C
    denom = torch.bmm(row_norm, col_norm) + 1e-7
    return torch.bmm(flat, flat_t) / denom


def get_vgg16_layer(model):
    """Split ``model.features`` into five sequential stages.

    The stages take the modules at index ranges [0, 7), [7, 14), [14, 24),
    [24, 34) and [34, 43) of ``model.features``.  These are presumably the
    conv blocks of a torchvision VGG16-BN; confirm against the caller.  Each
    stage is a fresh ``nn.Sequential`` (children renumbered from "0") that
    shares the original module objects.

    Args:
        model: object whose ``features`` supports integer indexing.

    Returns:
        Tuple ``(layer0, layer1, layer2, layer3, layer4)``.
    """
    boundaries = (0, 7, 14, 24, 34, 43)
    return tuple(
        nn.Sequential(*[model.features[i] for i in range(start, stop)])
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    )


def reshape_transform(tensor, height=28, width=28):
    """Turn a (tokens, batch, channels) token tensor into a CNN-style map.

    The first token is dropped.  The remaining ``height * width`` tokens are
    laid out row-major on a ``height x width`` grid, and channels are moved
    to dim 1.

    Args:
        tensor: tensor of shape (1 + height*width, B, D).
        height: grid height.
        width: grid width.

    Returns:
        Tensor of shape (B, D, height, width).
    """
    tokens = tensor.permute(1, 0, 2)[:, 1:, :]   # (B, L-1, D), first token dropped
    batch, _, dim = tokens.shape
    grid = tokens.reshape(batch, height, width, dim)
    # channels first, like CNN activations
    return grid.permute(0, 3, 1, 2)

class OneModel(nn.Module):
    """Few-shot segmentation model combining CLIP cues with a PFENet head.

    The following is visible from this file:

    * CLIP image features of the query and of the mask-multiplied support
      images give a max-cosine similarity prior (``get_similarity``).
    * ``get_img_cam`` produces query CAMs from the CLIP model, the
      precomputed text embeddings and a Grad-CAM on the last visual block.
    * Both are fed to ``self.pfenet`` together with the images.
    * An auxiliary affinity loss supervises ``attn_pred``, a token-affinity
      map built from fused CLIP query features, with PAR-refined CAM labels.
    """

    def __init__(self, args, cls_type=None):
        """Build the model.

        Args:
            args: config object.  Reads ``data_set`` ('pascal' or 'coco'),
                ``annotation_root`` and ``clip_path``.
            cls_type: unused.
        """
        super(OneModel, self).__init__()
        self.pfenet = PFENet()
        self.dataset = args.data_set
        # NOTE(review): fixed to 1-shot; the multi-shot branch in forward()
        # is only reachable if this is changed.
        self.shot = 1

        self.annotation_root = args.annotation_root
        self.clip_model, _ = clip.load(args.clip_path, "cpu")
        # Text embeddings are computed once, at construction time.
        # NOTE(review): any other data_set leaves bg/fg_text_features unset,
        # so forward() would fail with AttributeError.
        if self.dataset == 'pascal':
            self.bg_text_features = zeroshot_classifier_bg(BACKGROUND_CATEGORY_EACH, ['a clean origami {}.'],
                                                        self.clip_model)
            self.fg_text_features = zeroshot_classifier(new_class_names, ['a clean origami {}.'],
                                                        self.clip_model)
        elif self.dataset == 'coco':
            self.bg_text_features = zeroshot_classifier_bg(BACKGROUND_CATEGORY_COCO_EACH, ['a clean origami {}.'],
                                                        self.clip_model)
            self.fg_text_features = zeroshot_classifier(new_class_names_coco, ['a clean origami {}.'],
                                                        self.clip_model)
        # Fuses the stacked per-layer CLIP query features into one map.
        self.decoder_fts_fuse = SegFormerHead(in_channels=768, embedding_dim=256,
                                              index=12)
        # Refines pseudo masks into hard labels for the affinity loss.
        self.par = PAR(num_iter=20, dilations=[1, 2, 4, 8, 12, 24])

    def forward(self, x, x_cv2, que_name, class_name, y_m=None, y_b=None, s_x=None, s_y=None, cat_idx=None):
        """Run one episode.

        Args:
            x: query images; the last two dims are (H, W).
            x_cv2, que_name, class_name: forwarded to ``get_img_cam``.
            y_m: forwarded to ``self.pfenet`` (presumably the query mask).
            y_b, cat_idx: unused.
            s_x: support images, (B, shot, C, H, W).
            s_y: support masks, (B, shot, H, W); pixels equal to 1 are foreground.

        Returns:
            In training mode, ``(res[0], res[1] + res[2] + 0.1 * attn_loss,
            0 * res[1], 0 * res[1])``, where ``res`` is the PFENet output.
            In eval mode, ``(res, res, res)``.
        """
        s_x_pf = s_x  # PFENet receives the un-flattened support tensor
        mask = rearrange(s_y, "b n h w -> (b n) 1 h w")
        mask = (mask == 1).float()
        h, w = x.shape[-2:]
        s_x = rearrange(s_x, "b n c h w -> (b n) c h w")
        # Create/move tensors on the input's device instead of hard-coding
        # .cuda(), so the model also runs on CPU or a non-default GPU.
        device = x.device

        # extract the cnn features
        _, _, query_feat_2, query_feat_3, query_feat_4, query_feat_5 = self.extract_feats(x)

        # Extract the CLIP features.  Support images are multiplied by their
        # binary mask, resized to the support image size, before encoding.
        tmp_mask = F.interpolate(mask, size=s_x.shape[-2:], mode='nearest')
        s_x_mask = s_x * tmp_mask
        tmp_supp_clip_fts, supp_attn_maps = self.clip_model.encode_image(s_x_mask, h, w, extract=True)[:]
        tmp_que_clip_fts, que_attn_maps = self.clip_model.encode_image(x, h, w, extract=True)[:]

        # drop the first token of every layer (presumably the CLIP class token)
        supp_clip_fts = [ss[1:, :, :] for ss in tmp_supp_clip_fts]
        que_clip_fts = [ss[1:, :, :] for ss in tmp_que_clip_fts]

        # Per-layer token tensors (L, A, B) -> (A, B, sqrt(L), sqrt(L)) maps.
        # Assumes L is a perfect square and that A/B are batch/channels;
        # TODO confirm.
        tmp_supp_clip_feat_all = [ss.permute(1, 2, 0) for ss in supp_clip_fts]
        supp_clip_feat_all = [aw.reshape(
            tmp_supp_clip_feat_all[0].shape[0], tmp_supp_clip_feat_all[0].shape[1], int(math.sqrt(tmp_supp_clip_feat_all[0].shape[2])),
            int(math.sqrt(tmp_supp_clip_feat_all[0].shape[2]))).float()
            for aw in tmp_supp_clip_feat_all]

        tmp_que_clip_feat_all = [qq.permute(1, 2, 0) for qq in que_clip_fts]
        que_clip_feat_all = [aw.reshape(
            tmp_que_clip_feat_all[0].shape[0], tmp_que_clip_feat_all[0].shape[1], int(math.sqrt(tmp_que_clip_feat_all[0].shape[2])),
            int(math.sqrt(tmp_que_clip_feat_all[0].shape[2]))).float()
            for aw in tmp_que_clip_feat_all]

        # fuse the features of all CLIP layers, then predict a token-affinity map
        attn_fts = self.decoder_fts_fuse(torch.stack(que_clip_feat_all, dim=0))
        f_b, f_c, f_h, f_w = attn_fts.shape
        attn_fts_flatten = attn_fts.reshape(f_b, f_c, f_h*f_w)
        attn_pred = attn_fts_flatten.transpose(2, 1).bmm(attn_fts_flatten)
        attn_pred = torch.sigmoid(attn_pred)  # (f_b, f_h*f_w, f_h*f_w)

        # get the vvp: query-vs-masked-support similarity from entries 10 and 11
        # of the per-layer CLIP features
        if self.shot == 1:
            similarity2 = get_similarity(que_clip_feat_all[10], supp_clip_feat_all[10], s_y)
            similarity1 = get_similarity(que_clip_feat_all[11], supp_clip_feat_all[11], s_y)
        else:
            mask = rearrange(mask, "(b n) c h w -> b n c h w", n=self.shot)
            supp_clip_feat_all = [rearrange(ss, "(b n) c h w -> b n c h w", n=self.shot) for ss in supp_clip_feat_all]
            clip_similarity_1 = [get_similarity(que_clip_feat_all[11], supp_clip_feat_all[11][:, i, ...], mask=mask[:, i, ...]) for i in
                           range(self.shot)]
            clip_similarity_2 = [get_similarity(que_clip_feat_all[10], supp_clip_feat_all[10][:, i, ...], mask=mask[:, i, ...]) for i in
                           range(self.shot)]
            mask = rearrange(mask, "b n c h w -> (b n) c h w")
            similarity1 = torch.cat(clip_similarity_1, dim=1)
            similarity2 = torch.cat(clip_similarity_2, dim=1)
        clip_similarity = torch.cat([similarity1, similarity2], dim=1).to(device)
        clip_similarity = F.interpolate(clip_similarity, size=(query_feat_3.shape[2], query_feat_3.shape[3]), mode='bilinear', align_corners=True)

        # Get the vtp: get_img_cam receives the CLIP model, the text embeddings
        # and a Grad-CAM on ln_1 of the last visual transformer block.
        target_layers = [self.clip_model.visual.transformer.resblocks[-1].ln_1]
        cam = GradCAM(model=self.clip_model, target_layers=target_layers, reshape_transform=reshape_transform)
        img_cam_list, img_cam_list_PI = get_img_cam(x_cv2, que_name, class_name, self.clip_model, self.bg_text_features, self.fg_text_features, cam, self.annotation_root, self.training, attn_pred)
        img_cam_list = [F.interpolate(t_img_cam.unsqueeze(0).unsqueeze(0), size=(query_feat_3.shape[2], query_feat_3.shape[3]), mode='bilinear',
                                      align_corners=True) for t_img_cam in img_cam_list]
        img_cam_list_PI = [F.interpolate(t_img_cam.unsqueeze(0).unsqueeze(0), size=(query_feat_3.shape[2], query_feat_3.shape[3]), mode='bilinear',
                                      align_corners=True) for t_img_cam in img_cam_list_PI]
        img_cam = torch.cat(img_cam_list, 0)
        img_cam_PI = torch.cat(img_cam_list_PI, 0)
        # the plain CAM (not the _PI variant) serves as the pseudo mask
        gen_fake_mask = img_cam.squeeze(1)
        img_cam = torch.cat([img_cam, img_cam_PI], 1)

        # Supervise with pseudo masks.  Each CAM is stacked as
        # [background score, CAM] and refined by PAR into hard labels.
        # valid_key maps the argmax channel index to the label value (0 -> 0, 1 -> 1).
        # It is loop-invariant, so it is built once.
        valid_key = torch.from_numpy(np.pad([1], (1, 0), mode='constant')).to(device)
        cam_list = []
        for f, fake_mask in enumerate(gen_fake_mask):
            fake_mask = fake_mask.unsqueeze(0).to(device)
            bg_score = torch.pow(1 - torch.max(fake_mask, dim=0, keepdims=True)[0], 1)
            fake_mask = torch.cat([bg_score, fake_mask], dim=0)

            with torch.no_grad():
                cam_labels = _refine_cams(self.par, x[f], fake_mask, valid_key)

            cam_list.append(cam_labels)
        all_cam_labels = torch.stack(cam_list, dim=0)
        attn_mask = get_mask_by_radius(h=h//16, w=w//16, radius=8)
        fts_cam = all_cam_labels
        aff_label = cams_to_affinity_label(fts_cam, mask=attn_mask, ignore_index=255)
        attn_loss, pos_count, neg_count = get_aff_loss(attn_pred, aff_label)
        res = self.pfenet(x, s_x_pf, s_y, y_m, img_cam, clip_similarity)

        # Loss
        if self.training:
            return res[0], res[1]+res[2] + 0.1*attn_loss, 0*res[1], 0*res[1]
        else:
            return res, res, res

    def get_optim(self, model, args, LR):
        """Create the AdamW optimizer.

        Only ``model.pfenet`` and ``self.decoder_fts_fuse`` are optimized, both
        at ``10 * LR``.  The CLIP model is not included.

        Args:
            model: model whose ``pfenet`` parameters are optimized
                (presumably ``self``).
            args: config; ``args.weight_decay`` is used.
            LR: base learning rate.

        Returns:
            ``torch.optim.AdamW`` instance.
        """
        param_groups = list(self.decoder_fts_fuse.parameters())
        optimizer = torch.optim.AdamW(
            [
                {'params': model.pfenet.parameters(), "lr": LR * 10},
                {'params': param_groups, "lr": LR * 10},
            ], lr=LR, weight_decay=args.weight_decay, betas=(0.9, 0.999))
        return optimizer

    def freeze_modules(self, model):
        """Disable gradients for ``layer0``..``layer4`` of ``model.pfenet``."""
        frozen = (model.pfenet.layer0, model.pfenet.layer1, model.pfenet.layer2,
                  model.pfenet.layer3, model.pfenet.layer4)
        for layer in frozen:
            for param in layer.parameters():
                param.requires_grad = False

    def extract_feats(self, x, mask=None):
        """Run ``x`` through PFENet's ``layer0``..``layer4`` without gradients.

        Args:
            x: input images.
            mask: optional 4-D mask.  If given, it is resized (nearest) to
                ``x``'s spatial size and multiplied into ``x`` first.

        Returns:
            List of six tensors: the outputs of layer0 through layer4, followed
            by the layer4 output a second time (callers unpack six values).
        """
        results = []
        with torch.no_grad():
            if mask is not None:
                tmp_mask = F.interpolate(mask, size=x.shape[-2:], mode='nearest')
                x = x * tmp_mask
            feat = self.pfenet.layer0(x)
            results.append(feat)
            layers = [self.pfenet.layer1, self.pfenet.layer2, self.pfenet.layer3, self.pfenet.layer4]
            for layer in layers:
                feat = layer(feat)
                results.append(feat.clone())
            results.append(feat)
        return results

def _refine_cams(ref_mod, images, cams, valid_key):
    """Refine one image's CAM stack and return hard per-pixel labels.

    Args:
        ref_mod: refinement callable invoked as ``ref_mod(images, cams)`` on
            batched float tensors (e.g. the PAR module).
        images: a single image without batch dim.
        cams: a single (K, H, W) score stack without batch dim.
        valid_key: 1-D tensor mapping an argmax channel index to a label value.

    Returns:
        Label map with the batch dim removed: ``valid_key`` looked up at the
        argmax over the channel dim of the refined scores.
    """
    refined = ref_mod(images.unsqueeze(0).float(), cams.unsqueeze(0).float())
    labels = valid_key[refined.argmax(dim=1)]
    return labels.squeeze(0)

def get_mask_by_radius(h=20, w=20, radius=8):
    """Build a local-neighbourhood mask over a row-major flattened ``h x w`` grid.

    Entry ``[i, j]`` is 1 when grid positions ``i`` and ``j`` differ by at most
    ``radius`` along the row axis and along the column axis (a square
    window), and 0 otherwise.  The result is symmetric.

    This is vectorised with NumPy broadcasting.  The previous nested Python
    loop was O(h*w*radius^2) per call and ran on every forward pass.

    Args:
        h: grid height.
        w: grid width.
        radius: half-width of the square window.

    Returns:
        ``np.ndarray`` of shape (h*w, h*w), dtype float64.
    """
    hw = h * w
    if hw <= 0:
        # degenerate grid: same result (or same error for negatives) as before
        return np.zeros((hw, hw))
    rows, cols = np.divmod(np.arange(hw), w)
    near_rows = np.abs(rows[:, None] - rows[None, :]) <= radius
    near_cols = np.abs(cols[:, None] - cols[None, :]) <= radius
    return (near_rows & near_cols).astype(np.float64)

def cams_to_affinity_label(cam_label, mask=None, ignore_index=255, size=(473, 473)):
    """Turn per-pixel CAM labels into a pairwise affinity target.

    ``cam_label`` is resized (nearest) to ``(size[0] // 16, size[1] // 16)``
    and flattened to N positions.  Entry ``[b, i, j]`` of the result is 1 when
    positions ``i`` and ``j`` carry the same label and 0 otherwise.  Pairs
    where ``mask == 0`` are set to ``ignore_index``, and so is every row and
    column whose label equals ``ignore_index``.

    Args:
        cam_label: label map of shape (B, H, W).
        mask: optional (N, N) array or tensor; zero entries mark pairs to ignore.
        ignore_index: label value that marks ignored pixels / pairs.
        size: reference image size the labels are resized against.  Defaults
            to the previously hard-coded 473x473 crop.

    Returns:
        LongTensor of shape (B, N, N).
    """
    b = cam_label.shape[0]
    h, w = size
    cam_label_resized = F.interpolate(cam_label.unsqueeze(1).type(torch.float32),
                                      size=[h // 16, w // 16], mode="nearest")
    _cam_label = cam_label_resized.reshape(b, 1, -1)            # B x 1 x N
    # [b, i, j] = label[j] == label[i], via broadcasting
    aff_label = (_cam_label == _cam_label.transpose(1, 2)).type(torch.long)

    if mask is not None:
        mask = torch.as_tensor(mask, device=aff_label.device)
        aff_label[:, mask == 0] = ignore_index

    # ignore every column (label[j]) and row (label[i]) with an ignored label
    ignored = _cam_label == ignore_index                          # B x 1 x N
    aff_label.masked_fill_(ignored | ignored.transpose(1, 2), ignore_index)
    return aff_label


def get_aff_loss(inputs, targets):
    """Balanced affinity loss.

    Positive pairs (``targets == 1``) are penalised by ``1 - inputs`` and
    negative pairs (``targets == 0``) by ``inputs``.  Each term is averaged
    over its own count (plus 1), and the two are weighted 0.5 / 0.5.  Any
    other target value (e.g. an ignore index) contributes nothing.

    Args:
        inputs: predicted affinities, same shape as ``targets``.
        targets: affinity labels.

    Returns:
        Tuple ``(loss, pos_count, neg_count)``; both counts include the +1.
    """
    pos_mask = (targets == 1).type(torch.int16)
    neg_mask = (targets == 0).type(torch.int16)
    pos_count = pos_mask.sum() + 1
    neg_count = neg_mask.sum() + 1

    pos_term = (pos_mask * (1 - inputs)).sum() / pos_count
    neg_term = (neg_mask * inputs).sum() / neg_count
    loss = 0.5 * pos_term + 0.5 * neg_term
    return loss, pos_count, neg_count